{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.8 网络中的网络（NiN）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.0\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "import sys\n",
    "# Add the parent directory to the module search path before importing\n",
    "# d2lzh_pytorch (presumably the package lives there -- confirm).\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "# Use the GPU when CUDA is available; otherwise fall back to the CPU.\n",
    "device_name = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "device = torch.device(device_name)\n",
    "\n",
    "# Report the PyTorch version and the selected device.\n",
    "print(torch.__version__)\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.8.1 NiN块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nin_block(in_channels, out_channels, kernel_size, stride, padding):\n",
    "    blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),\n",
    "                        nn.ReLU(),\n",
    "                        nn.Conv2d(out_channels, out_channels, kernel_size=1),\n",
    "                        nn.ReLU(),\n",
    "                        nn.Conv2d(out_channels, out_channels, kernel_size=1),\n",
    "                        nn.ReLU())\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.8.2 NiN模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NiN model: NiN blocks separated by 3x3 max-pooling layers with stride 2,\n",
    "# followed by dropout, a 10-channel NiN block, global average pooling and\n",
    "# a flatten layer.\n",
    "net = nn.Sequential(\n",
    "    nin_block(1, 96, kernel_size=11, stride=4, padding=0),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nin_block(96, 256, kernel_size=5, stride=1, padding=2),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nin_block(256, 384, kernel_size=3, stride=1, padding=1),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nn.Dropout(0.5),\n",
    "    # There are 10 label classes, so the last block emits 10 channels\n",
    "    nin_block(384, 10, kernel_size=3, stride=1, padding=1),\n",
    "    # Global average pooling. The adaptive layer always reduces each channel\n",
    "    # to 1x1, so it stays global when the input is resized; a fixed\n",
    "    # kernel_size=5 window only matches the 5x5 map produced by 224x224 inputs.\n",
    "    nn.AdaptiveAvgPool2d((1, 1)),\n",
    "    # Flatten the 4-D output into a 2-D output of shape (batch size, 10)\n",
    "    d2l.FlattenLayer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 output shape:  torch.Size([1, 96, 54, 54])\n",
      "1 output shape:  torch.Size([1, 96, 26, 26])\n",
      "2 output shape:  torch.Size([1, 256, 26, 26])\n",
      "3 output shape:  torch.Size([1, 256, 12, 12])\n",
      "4 output shape:  torch.Size([1, 384, 12, 12])\n",
      "5 output shape:  torch.Size([1, 384, 5, 5])\n",
      "6 output shape:  torch.Size([1, 384, 5, 5])\n",
      "7 output shape:  torch.Size([1, 10, 5, 5])\n",
      "8 output shape:  torch.Size([1, 10, 1, 1])\n",
      "9 output shape:  torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "# Feed a random tensor of shape (1, 1, 224, 224) through the network one\n",
    "# child module at a time, printing the output shape after each one.\n",
    "X = torch.rand(1, 1, 224, 224)\n",
    "\n",
    "for layer_name, layer in net.named_children():\n",
    "    X = layer(X)\n",
    "    print(layer_name, 'output shape: ', X.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.8.3 获取数据和训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.0101, train acc 0.513, test acc 0.734, time 260.9 sec\n",
      "epoch 2, loss 0.0050, train acc 0.763, test acc 0.754, time 175.1 sec\n",
      "epoch 3, loss 0.0041, train acc 0.808, test acc 0.826, time 151.0 sec\n",
      "epoch 4, loss 0.0037, train acc 0.828, test acc 0.827, time 151.0 sec\n",
      "epoch 5, loss 0.0034, train acc 0.839, test acc 0.831, time 151.0 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 128\n",
    "# If an \"out of memory\" error is reported, reduce batch_size or resize\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
    "\n",
    "# Optimisation settings: Adam learning rate and number of training epochs\n",
    "lr = 0.002\n",
    "num_epochs = 5\n",
    "optimizer = optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
